import torch
import torch.utils.checkpoint
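# init_ipex enables Intel Extension for PyTorch (IPEX) when it is available (see library.device_utils)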
from library.device_utils import init_ipex
init_ipex()

from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

from library.original_unet import SampleOutput


def unet_forward_XTI(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SampleOutput, Tuple]:
    r"""
    Args:
        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states (sequence of `torch.FloatTensor`): per-cross-attention-layer encoder hidden
            states, each of shape (batch, sequence_length, feature_dim); XTI passes one embedding per
            cross-attention layer (for the SD1.x UNet: indices 0-5 down, 6 mid, 7-15 up)
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a `SampleOutput` instead of a plain tuple.

    Returns:
        `SampleOutput` or `tuple`:
        `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
    """
    # By default, samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    # In other words, the sample size must normally be a multiple of 2 ** (number of upsamplers), i.e. 64;
    # other sizes are handled by overriding the upsample size where needed, but image quality may suffer,
    # so keeping dimensions divisible by 64 is recommended.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    # (i.e. when the spatial size is not divisible by 64, pass the target size to the upsamplers)
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        # logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # 1. time
    timesteps = timestep
    timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only normalizes irregular timestep inputs

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    # (perhaps the cast could simply live inside time_proj instead)
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    down_i = 0
    for downsample_block in self.down_blocks:
        # the down block forward could be made to always accept encoder_hidden_states,
        # but branching on has_cross_attention here is arguably easier to follow
        if downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],  # two cross-attention layers per down block
            )
            down_i += 2
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    # 4. mid
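    # the mid block has a single cross-attention layer; it takes entry 6 of the per-layer embedding list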
    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])

    # 5. up
    up_i = 7  # entries 0-5 were consumed by the down blocks, entry 6 by the mid block
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection

        # if we have not reached the final block and need to forward the upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
                upsample_size=upsample_size,
            )
            up_i += 3
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return SampleOutput(sample=sample)
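

# Illustrative helper (an assumption, not part of the original module): unet_forward_XTI
# expects encoder_hidden_states to be a sequence of per-layer text embeddings, one per
# cross-attention layer of the SD1.x UNet (16 in total: 6 down, 1 mid, 9 up). A stacked
# tensor of shape (num_cross_attention_layers, batch, sequence_length, feature_dim)
# could be split into that sequence like this:
def split_xti_embeddings(stacked: torch.Tensor) -> List[torch.Tensor]:
    # one tensor of shape (batch, sequence_length, feature_dim) per cross-attention layer
    return [stacked[i] for i in range(stacked.shape[0])]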


def downblock_forward_XTI(
    self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
    output_states = ()

    # each attention block consumes its own entry of encoder_hidden_states
    for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
        if self.training and self.gradient_checkpointing:
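            # torch.utils.checkpoint.checkpoint only forwards positional inputs to the wrapped
            # function, so this small closure carries the return_dict keyword through for modules
            # (like the attention block) that need it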

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

        output_states += (hidden_states,)

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)

        output_states += (hidden_states,)

    return hidden_states, output_states


def upblock_forward_XTI(
    self,
    hidden_states,
    res_hidden_states_tuple,
    temb=None,
    encoder_hidden_states=None,
    upsample_size=None,
):
    for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states
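

# Hedged usage sketch: one way these forwards might be bound onto a UNet instance.
# The helper name apply_xti_forwards is an assumption for illustration, not part of
# the original module; the UNet is assumed to be library.original_unet's
# UNet2DConditionModel (the same module SampleOutput comes from).
def apply_xti_forwards(unet) -> None:
    # rebind the top-level forward so each block receives its own per-layer embedding(s)
    unet.forward = unet_forward_XTI.__get__(unet)
    # rebind only the blocks that actually use cross attention
    for block in list(unet.down_blocks):
        if getattr(block, "has_cross_attention", False):
            block.forward = downblock_forward_XTI.__get__(block)
    for block in list(unet.up_blocks):
        if getattr(block, "has_cross_attention", False):
            block.forward = upblock_forward_XTI.__get__(block)
    # afterwards the UNet is called as usual, but with a sequence of 16 per-layer text
    # embeddings, e.g. unet(noisy_latents, timesteps, embeddings_list).sample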
